import base64
import io
from PIL import Image
import re
import json
from typing import List, Dict, Any, Union
from io import BytesIO
import tempfile
import os
import glob

def load_and_encode_image(image_input: Union[str, bytes]) -> str:
    """Return a base64 (UTF-8 text) representation of an image.

    Args:
        image_input: One of
            * raw image ``bytes`` -- base64-encoded and returned;
            * a path to an existing file -- the file is read in binary mode
              and its contents base64-encoded;
            * a string that is not an existing file -- if it is shorter than
              500 characters and contains no ``\\n``/``\\r`` it is treated as a
              mistyped path, otherwise it is assumed to already be base64 and
              returned unchanged.

    Raises:
        FileNotFoundError: for a short, single-line string that is not a file.
        ValueError: for any other input type.
    """
    if isinstance(image_input, bytes):
        raw = image_input
    elif isinstance(image_input, str):
        if not os.path.isfile(image_input):
            # Short single-line strings are far more likely to be bad paths
            # than base64 payloads.
            single_line = "\n" not in image_input and "\r" not in image_input
            if len(image_input) < 500 and single_line:
                raise FileNotFoundError(f"Invalid image path: {image_input}")
            # Otherwise assume the caller already handed us base64 text.
            return image_input
        with open(image_input, "rb") as handle:
            raw = handle.read()
    else:
        raise ValueError("Unsupported image input format.")
    return base64.b64encode(raw).decode("utf-8")
def cleanup_temp_images(directory: Union[str, None] = None) -> None:
    """Best-effort removal of leftover ``tmp*.png`` files.

    ``tmp`` is the default filename prefix used by the ``tempfile`` module, so
    this targets PNGs presumably created via ``tempfile`` (NOTE(review):
    confirm against the code that creates them).

    Args:
        directory: Directory to clean. Defaults to ``tempfile.gettempdir()``,
            the directory ``tempfile`` itself writes to. On Linux this is
            normally ``/tmp``, but it honours ``TMPDIR`` and works on
            macOS/Windows, where ``/tmp`` is not where temp files land.
    """
    temp_dir = directory if directory is not None else tempfile.gettempdir()
    # Escape the directory so glob metacharacters in the path are literal.
    pattern = os.path.join(glob.escape(temp_dir), "tmp*.png")
    for path in glob.glob(pattern):
        try:
            os.remove(path)
        except OSError:
            # Deliberately best-effort: the file may already be gone, be a
            # directory, or be locked. Unlike a bare ``except:``, this does
            # not swallow KeyboardInterrupt/SystemExit.
            pass

def encode_image_for_qwen(image_path: str, max_size: int = 512) -> Dict[str, str]:
    """Encode an image file as an image content part carrying a PNG data URI.

    The image is downscaled with LANCZOS resampling, preserving aspect ratio,
    so that its longest side is at most ``max_size`` pixels. It is then
    converted to RGB and PNG-encoded.

    Args:
        image_path: Path to an image file Pillow can open.
        max_size: Maximum length in pixels of the longest side (default 512).

    Returns:
        ``{"type": "image", "image": "data:image/png;base64,<data>"}``.

    Raises:
        ValueError: if ``max_size`` is smaller than 1.
    """
    if max_size < 1:
        raise ValueError("max_size must be a positive integer")
    with Image.open(image_path) as img:
        longest = max(img.width, img.height)
        if longest > max_size:
            scale = max_size / longest
            # Clamp each side to at least 1px: for very elongated images the
            # short side would otherwise truncate to 0 and the resize fails.
            new_size = (
                max(1, int(img.width * scale)),
                max(1, int(img.height * scale)),
            )
            img = img.resize(new_size, Image.Resampling.LANCZOS)
        img = img.convert("RGB")
        with io.BytesIO() as buffer:
            img.save(buffer, format="PNG")
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return {"type": "image", "image": f"data:image/png;base64,{encoded}"}

def encode_image(image_path, max_size=512):
    """Return the image at ``image_path`` as a base64-encoded PNG string.

    The image is downscaled with LANCZOS resampling, preserving aspect ratio,
    so its longest side is at most ``max_size`` pixels. It is then converted
    to RGB before PNG encoding.

    Args:
        image_path: Path to an image file Pillow can open.
        max_size: Maximum length in pixels of the longest side (default 512).

    Returns:
        The PNG bytes, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        ValueError: if ``max_size`` is smaller than 1.
    """
    if max_size < 1:
        raise ValueError("max_size must be a positive integer")
    with Image.open(image_path) as img:
        longest = max(img.width, img.height)
        if longest > max_size:
            scale = max_size / longest
            # Clamp each side to at least 1px: for very elongated images the
            # short side would otherwise truncate to 0 and the resize fails.
            new_size = (
                max(1, int(img.width * scale)),
                max(1, int(img.height * scale)),
            )
            img = img.resize(new_size, Image.Resampling.LANCZOS)
        rgb_img = img.convert("RGB")
        with io.BytesIO() as byte_stream:
            rgb_img.save(byte_stream, format="PNG")
            return base64.b64encode(byte_stream.getvalue()).decode("utf-8")

# Function to extract JSON from LLM response
def extract_json(text):
    """Return the first JSON object embedded in ``text``, or ``None``.

    Starting at each ``{`` in turn, the function tries to decode one complete
    JSON object with ``json.JSONDecoder.raw_decode``, which ignores whatever
    follows the object. The first successful decode is returned. This
    tolerates prose before or after the JSON, several JSON objects in a single
    response, and stray braces.

    Args:
        text: Model response text. ``None`` or an empty string yields ``None``.

    Returns:
        The decoded ``dict``, or ``None`` if no ``{`` begins valid JSON.
    """
    if not text:
        return None
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            # Not valid JSON from this brace; try the next candidate.
            start = text.find("{", start + 1)
            continue
        return obj
    return None

def extract_json1(response_str):
    """Parse JSON from a response, preferring a fenced ```json block.

    If the text contains a fenced ```json block, its inner text is parsed.
    Otherwise the whole response, whitespace-stripped, is parsed. On a decode
    failure the error is printed and ``None`` is returned.
    """
    fenced = re.search(r"```json\s*(.*?)\s*```", response_str, re.DOTALL)
    payload = fenced.group(1) if fenced else response_str.strip()
    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError as err:
        print("JSON decode error:", err)
        return None
    return parsed

def print_stream(stream):
    """Display the latest message of every chunk yielded by ``stream``.

    Each chunk is indexed as ``chunk["messages"][-1]``. Tuple messages are
    passed to ``print``; any other message has its ``pretty_print()`` method
    called.
    """
    for chunk in stream:
        latest = chunk["messages"][-1]
        if not isinstance(latest, tuple):
            latest.pretty_print()
            continue
        print(latest)
            
def save_all_states_to_json(states, path):
    """Write ``states`` to ``path`` as pretty-printed UTF-8 JSON.

    Each state is copied into a new ``dict``. If it has a ``"state"`` entry
    containing ``"messages"``, that key is left out of the copy, so message
    history is not written. The caller's ``states`` are never modified.

    Args:
        states: Iterable of mappings to serialize.
        path: Destination file path; an existing file is overwritten.
    """
    cleaned_states = []
    for s in states:
        state = dict(s)
        if "state" in state and "messages" in state["state"]:
            # Copy the nested mapping before deleting: ``dict(s)`` is shallow,
            # so deleting in place would strip messages from the caller's data.
            inner = dict(state["state"])
            del inner["messages"]
            state["state"] = inner
        cleaned_states.append(state)

    # ensure_ascii=False emits raw non-ASCII text, so pin the file encoding
    # instead of relying on the platform's locale default.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(cleaned_states, f, indent=2, ensure_ascii=False)

def is_prediction_correct(output: str, correct_choice: str, correct_ans: str) -> bool:
    """Return True if a model's output selects the expected option.

    After stripping surrounding whitespace, ``output`` counts as correct when
    any of the following holds:

    * it equals ``correct_choice`` (e.g. ``"A"``) or ``correct_ans`` (the
      option text);
    * it contains ``"<correct_choice>."`` that is not preceded by a word
      character (e.g. ``"The answer is A. Paris"``);
    * one of its lines, stripped, equals ``correct_choice``.
    """
    resp = output.strip()
    if resp == correct_choice or resp == correct_ans:
        return True
    # Require the "X." marker to start a token, so that e.g. "USA." is not
    # read as choice "A". An empty choice would reduce to "any period", so it
    # is skipped.
    if correct_choice and re.search(
        rf"(?<!\w){re.escape(correct_choice)}\.", resp
    ):
        return True
    return any(line.strip() == correct_choice for line in resp.splitlines())

def answers_match(a: str, b: str, choices: List[str]) -> bool:
    """Return True if answers ``a`` and ``b`` refer to the same option.

    If ``b`` is a single option letter (case-insensitive, ``"A"`` maps to
    ``choices[0]``), ``a`` is checked against that letter and its choice text
    with ``is_prediction_correct``; the check is then repeated with the roles
    of ``a`` and ``b`` swapped. Finally the two strings are compared
    case-insensitively after stripping whitespace.

    Answers that are not a single A-Z letter, or whose letter is past the end
    of ``choices``, skip the letter-based check instead of raising or wrapping
    to a negative index.
    """
    def _option_text(ans: str) -> Union[str, None]:
        # Map an option letter to its choice text, or None when ``ans`` is not
        # a letter indexing into ``choices``.
        letter = ans.strip().upper()
        if len(letter) != 1 or not "A" <= letter <= "Z":
            return None
        idx = ord(letter) - ord("A")
        return choices[idx] if idx < len(choices) else None

    b_text = _option_text(b)
    if b_text is not None and is_prediction_correct(a, b.strip(), b_text):
        return True
    a_text = _option_text(a)
    if a_text is not None and is_prediction_correct(b, a.strip(), a_text):
        return True
    return a.strip().lower() == b.strip().lower()